{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be64308",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import pandas as pd\n",
    "import rdkit\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "import deepchem\n",
    "\n",
    "from pytorch_model_summary import summary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4aa86ff",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized, it serves an educational purpose. It is written for CPU, it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all components that can help to understand how an autoregressive model (ARM) works, and it should be rather easy to extend it to more sophisticated models. This code could be run almost on any laptop/PC, and it takes a couple of minutes top to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d125d7bd",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bbe116b",
   "metadata": {},
   "source": [
    "This dataset is a slight modification of a widely used benchmark Tox21 (more [here](https://deepchem.readthedocs.io/en/latest/api_reference/moleculenet.html)). Each molecule is originally represented as SMILES and then tokenized (we used [`BasicSmileTokenizer`](https://deepchem.readthedocs.io/en/latest/api_reference/featurizers.html?highlight=tokenizer#basicsmilestokenizer) from [`DeepChem`](https://deepchem.readthedocs.io/en/latest/index.html)). We limit ourselves to SMILES no longer than 100 tokens and no smaller than 10 tokens. The resulting data is a matrix $7410 \\times 100$ containing integers. "
   ]
  },
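  {
   "cell_type": "markdown",
   "id": "d0e1f2a3",
   "metadata": {},
   "source": [
    "As a quick illustration of the tokenization step, the snippet below shows how `BasicSmilesTokenizer` splits a SMILES string into tokens (a minimal sketch, assuming DeepChem is installed; the aspirin SMILES is an arbitrary example, not taken from the dataset)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f2a3b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustration of the tokenization step (aspirin as an arbitrary example)\n",
    "from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer\n",
    "\n",
    "tokenizer = BasicSmilesTokenizer()\n",
    "tokens = tokenizer.tokenize('CC(=O)Oc1ccccc1C(=O)O')\n",
    "print(tokens)\n",
    "print(len(tokens))  # only SMILES with 10-100 tokens are kept in the dataset"
   ]
  },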
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3311cfb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tox21(Dataset):\n",
    "    \"\"\"A filtered version of DeepChem Tox21 dataset. (Only SMILES longer than 9 and shorter than 101, the number of token values: 112)\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        smiles = np.load(os.path.join('molecules', 'tox21_smiles_tokenized.npy'), allow_pickle=True)\n",
    "        smiles = torch.from_numpy(smiles).long()\n",
    "        \n",
    "        if mode == 'train':\n",
    "            self.data = smiles[:6600]\n",
    "        elif mode == 'val':\n",
    "            self.data = smiles[6600:7000]\n",
    "        else:\n",
    "            self.data = smiles[7000:]\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf13336a",
   "metadata": {},
   "source": [
    "## Transformers: A combination of Multi-head Self-Attention, LayerNormalization and MLP"
   ]
  },
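  {
   "cell_type": "markdown",
   "id": "f0a1b2c3",
   "metadata": {},
   "source": [
    "For reference, each head computes scaled dot-product attention,\n",
    "\n",
    "$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{Q K^{\\top}}{\\sqrt{D}}\\right) V,$$\n",
    "\n",
    "where $Q$, $K$ and $V$ are linear projections of the input. In the implementation below, the $1/\\sqrt{D}$ factor is applied as $D^{-1/4}$ to both $Q$ and $K$ before taking their product, and, for simplicity, each of the $H$ heads works in the full $D$ dimensions (instead of the more common $D/H$ split); the $H$ head outputs are concatenated and mixed by a final linear layer. For causal (autoregressive) attention, the entries of $Q K^{\\top}$ above the diagonal are set to $-\\infty$ before the softmax, so position $t$ attends only to positions $\\leq t$."
   ]
  },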
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c6750b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadSelfAttention(nn.Module):\n",
    "    def __init__(self, num_emb, num_heads=8):\n",
    "        super().__init__()\n",
    "        \n",
    "        # hyperparams\n",
    "        self.D = num_emb\n",
    "        self.H = num_heads\n",
    "        \n",
    "        # weights for self-attention\n",
    "        self.w_k = nn.Linear(self.D, self.D * self.H)\n",
    "        self.w_q = nn.Linear(self.D, self.D * self.H)\n",
    "        self.w_v = nn.Linear(self.D, self.D * self.H)\n",
    "        \n",
    "        # weights for a combination of multiple heads\n",
    "        self.w_c = nn.Linear(self.D * self.H, self.D)\n",
    "            \n",
    "    def forward(self, x, causal=True):\n",
    "        # x: B(atch) x T(okens) x D(imensionality)\n",
    "        B, T, D = x.size()\n",
    "        \n",
    "        # keys, queries, values\n",
    "        k = self.w_k(x).view(B, T, self.H, D) # B x T x H x D\n",
    "        q = self.w_q(x).view(B, T, self.H, D) # B x T x H x D\n",
    "        v = self.w_v(x).view(B, T, self.H, D) # B x T x H x D\n",
    "        \n",
    "        k = k.transpose(1, 2).contiguous().view(B * self.H, T, D) # B*H x T x D\n",
    "        q = q.transpose(1, 2).contiguous().view(B * self.H, T, D) # B*H x T x D\n",
    "        v = v.transpose(1, 2).contiguous().view(B * self.H, T, D) # B*H x T x D\n",
    "        \n",
    "        k = k / (D**0.25) # scaling \n",
    "        q = q / (D**0.25) # scaling\n",
    "        \n",
    "        # kq\n",
    "        kq = torch.bmm(q, k.transpose(1, 2)) # B*H x T x T\n",
    "        \n",
    "        # if causal\n",
    "        if causal:\n",
    "            mask = torch.triu_indices(T, T, offset=1)\n",
    "            kq[..., mask[0], mask[1]] = float('-inf')\n",
    "        \n",
    "        # softmax\n",
    "        skq = F.softmax(kq, dim=2)\n",
    "        \n",
    "        # self-attention\n",
    "        sa = torch.bmm(skq, v) # B*H x T x D\n",
    "        sa = sa.view(B, self.H, T, D) # B x H x T x D\n",
    "        sa = sa.transpose(1, 2) # B x T x H x D\n",
    "        sa = sa.contiguous().view(B, T, D * self.H) # B x T x D*H\n",
    "        \n",
    "        out = self.w_c(sa) # B x T x D\n",
    "        \n",
    "        return out      "
   ]
  },
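  {
   "cell_type": "markdown",
   "id": "c3d4e5f6",
   "metadata": {},
   "source": [
    "A quick sanity check of the causal masking (a minimal sketch with arbitrary shapes, not part of the training pipeline): with `causal=True`, the output at position $t$ should not change when a token at a position $>t$ is modified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check of the causal mask (arbitrary shapes, illustration only)\n",
    "_mhsa = MultiHeadSelfAttention(num_emb=16, num_heads=2)\n",
    "_x = torch.randn(1, 5, 16)\n",
    "_x_mod = _x.clone()\n",
    "_x_mod[:, -1] = torch.randn(16)  # change only the last token\n",
    "\n",
    "with torch.no_grad():\n",
    "    _out = _mhsa(_x, causal=True)\n",
    "    _out_mod = _mhsa(_x_mod, causal=True)\n",
    "\n",
    "# outputs at positions 0..3 must be unaffected by the change at position 4\n",
    "print(torch.allclose(_out[:, :-1], _out_mod[:, :-1]))  # expected: True"
   ]
  },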
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c708cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, num_emb, num_neurons, num_heads=4):\n",
    "        super().__init__()\n",
    "        \n",
    "        # hyperparams\n",
    "        self.D = num_emb\n",
    "        self.H = num_heads\n",
    "        self.neurons = num_neurons\n",
    "        \n",
    "        # components\n",
    "        self.msha = MultiHeadSelfAttention(num_emb=self.D, num_heads=self.H)\n",
    "        self.layer_norm1 = nn.LayerNorm(self.D)\n",
    "        self.layer_norm2 = nn.LayerNorm(self.D)\n",
    "        \n",
    "        self.mlp = nn.Sequential(nn.Linear(self.D, self.neurons * self.D),\n",
    "                                nn.GELU(),\n",
    "                                nn.Linear(self.neurons * self.D, self.D))\n",
    "    \n",
    "    def forward(self, x, causal=True):\n",
    "        # Multi-Head Self-Attention\n",
    "        x_attn = self.msha(x, causal)\n",
    "        # LayerNorm\n",
    "        x = self.layer_norm1(x_attn + x)\n",
    "        # MLP\n",
    "        x_mlp = self.mlp(x)\n",
    "        # LayerNorm\n",
    "        x = self.layer_norm2(x_mlp + x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7105f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LossFun(nn.Module):\n",
    "    def __init__(self,):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.loss = nn.NLLLoss(reduction='none')\n",
    "    \n",
    "    def forward(self, y_model, y_true, reduction='sum'):\n",
    "        # y_model: B(atch) x T(okens) x V(alues)\n",
    "        # y_true: B x T      \n",
    "        B, T, V = y_model.size()\n",
    "        \n",
    "        y_model = y_model.view(B * T, V)\n",
    "        y_true = y_true.view(B * T,)\n",
    "        \n",
    "        loss_matrix = self.loss(y_model, y_true) # B*T\n",
    "        \n",
    "        if reduction == 'sum':\n",
    "            return torch.sum(loss_matrix)\n",
    "        elif reduction == 'mean':\n",
    "            loss_matrix = loss_matrix.view(B, T)\n",
    "            return torch.mean(torch.sum(loss_matrix, 1))\n",
    "        else:\n",
    "            raise ValueError('Reduction could be either `sum` or `mean`.')"
   ]
  },
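  {
   "cell_type": "markdown",
   "id": "e5f6a7b8",
   "metadata": {},
   "source": [
    "Since the model outputs log-probabilities (a log-softmax over the logits), `NLLLoss` applied to them is equivalent to cross-entropy applied to the raw logits. A small check of this equivalence (arbitrary shapes, illustration only):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a7b8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LossFun on log-probabilities should match cross-entropy on raw logits\n",
    "# (arbitrary shapes, illustration only)\n",
    "_B, _T, _V = 2, 5, 7\n",
    "_logits = torch.randn(_B, _T, _V)\n",
    "_targets = torch.randint(0, _V, (_B, _T))\n",
    "\n",
    "_loss_fun = LossFun()\n",
    "_nll = _loss_fun(F.log_softmax(_logits, dim=2), _targets, reduction='sum')\n",
    "_ce = F.cross_entropy(_logits.view(_B * _T, _V), _targets.view(_B * _T), reduction='sum')\n",
    "print(torch.allclose(_nll, _ce))  # expected: True"
   ]
  },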
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "561fbd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "    def __init__(self, num_tokens, num_token_vals, num_emb, num_neurons, num_heads=2, dropout_prob=0.1, num_blocks=10, device='cpu'):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Remember, always credit the author, even if it's you ;)\n",
    "        print('Transformer by JT.')\n",
    "        \n",
    "        # hyperparams\n",
    "        self.device = device\n",
    "        self.num_tokens = num_tokens\n",
    "        self.num_token_vals = num_token_vals\n",
    "        self.num_emb = num_emb\n",
    "        self.num_blocks = num_blocks\n",
    "        \n",
    "        # embedding layer\n",
    "        self.embedding = torch.nn.Embedding(num_token_vals, num_emb)\n",
    "        \n",
    "        # positional embedding\n",
    "        self.positional_embedding = nn.Embedding(num_tokens, num_emb)\n",
    "        \n",
    "        # transformer blocks\n",
    "        self.transformer_blocks = nn.ModuleList()\n",
    "        for _ in range(num_blocks):\n",
    "            self.transformer_blocks.append(TransformerBlock(num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads))\n",
    "        \n",
    "        # output layer (logits + softmax)\n",
    "        self.logits = nn.Sequential(nn.Linear(num_emb, num_token_vals))\n",
    "        \n",
    "        # dropout layer\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "        \n",
    "        # loss function\n",
    "        self.loss_fun = LossFun()\n",
    "        \n",
    "    def transformer_forward(self, x, causal=True, temperature=1.0):\n",
    "        # x: B(atch) x T(okens)\n",
    "        # embedding of tokens\n",
    "        x = self.embedding(x) # B x T x D\n",
    "        # embedding of positions\n",
    "        pos = torch.arange(0, x.shape[1], dtype=torch.long).unsqueeze(0).to(self.device)\n",
    "        pos_emb = self.positional_embedding(pos)\n",
    "        # dropout of embedding of inputs\n",
    "        x = self.dropout(x + pos_emb)\n",
    "        \n",
    "        # transformer blocks\n",
    "        for i in range(self.num_blocks):\n",
    "            x = self.transformer_blocks[i](x)\n",
    "        \n",
    "        # output logits\n",
    "        out = self.logits(x)\n",
    "        \n",
    "        return F.log_softmax(out/temperature, 2)\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def sample(self, batch_size=4, temperature=1.0):\n",
    "        x_seq = np.asarray([[self.num_token_vals - 1] for i in range(batch_size)])\n",
    "\n",
    "        # sample next tokens\n",
    "        for i in range(self.num_tokens-1):\n",
    "            xx = torch.tensor(x_seq, dtype=torch.long, device=self.device)\n",
    "            # process x and calculate log_softmax\n",
    "            x_log_probs = self.transformer_forward(xx, temperature=temperature)\n",
    "            # sample i-th tokens\n",
    "            x_i_sample = torch.multinomial(torch.exp(x_log_probs[:,i]), 1).to(self.device)\n",
    "            # update the batch with new samples\n",
    "            x_seq = np.concatenate((x_seq, x_i_sample.to('cpu').detach().numpy()), 1)\n",
    "        \n",
    "        return x_seq\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def top1_rec(self, x, causal=True):\n",
    "        x_prob = torch.exp(self.transformer_forward(x, causal=True))[:,:-1,:].contiguous()\n",
    "        _, x_rec_max = torch.max(x_prob, dim=2)\n",
    "        return torch.sum(torch.mean((x_rec_max.float() == x[:,1:].float().to(device)).float(), 1).float())\n",
    "    \n",
    "    def forward(self, x, causal=True, temperature=1.0, reduction='mean'):\n",
    "        # get log-probabilities\n",
    "        log_prob = self.transformer_forward(x, causal=causal, temperature=temperature)\n",
    "        \n",
    "        return self.loss_fun(log_prob[:,:-1].contiguous(), x[:,1:].contiguous(), reduction=reduction)"
   ]
  },
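  {
   "cell_type": "markdown",
   "id": "a7b8c9d0",
   "metadata": {},
   "source": [
    "Note how the loss is computed in `forward`: the log-probabilities at positions $0, \\ldots, T-2$ (`log_prob[:, :-1]`) are compared with the tokens at positions $1, \\ldots, T-1$ (`x[:, 1:]`), i.e., at every position the model is trained to predict the *next* token given all previous ones. `sample` mirrors this: starting from the initial token, the distribution at position $i$ is used to draw the token at position $i+1$."
   ]
  },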
  {
   "cell_type": "markdown",
   "id": "e863209e",
   "metadata": {},
   "source": [
    "### Auxiliary functions: training, evaluation, plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f679426d",
   "metadata": {},
   "source": [
    "It's rather self-explanatory, isn't it?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e53a0c3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None, device='cuda'):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model').to(device)\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    rec = 1.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t = model_best.forward(test_batch.to(device), reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        \n",
    "        rec_t = model_best.top1_rec(test_batch.to(device))\n",
    "        rec = rec + rec_t.item()\n",
    "        \n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "    rec = rec / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}, rec={rec}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}, val rec={rec}')\n",
    "\n",
    "    return loss, rec\n",
    "\n",
    "def plot_curve(name, nll_val, ylabel='nll'):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.savefig(name + '_' + ylabel + '_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3903be19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader, device='cuda'):\n",
    "    nll_val = []\n",
    "    rec_val = []\n",
    "    best_nll = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, batch in enumerate(training_loader):\n",
    "            loss = model.forward(batch.to(device))\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val, r_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "        rec_val.append(r_val)\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "        else:\n",
    "            if loss_val < best_nll:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                best_nll = loss_val\n",
    "                patience = 0\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "    rec_val = np.asarray(rec_val)\n",
    "\n",
    "    return nll_val, rec_val"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48de0294",
   "metadata": {},
   "source": [
    "### Initialize dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1941c317",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = Tox21(mode='train')\n",
    "val_data = Tox21(mode='val')\n",
    "test_data = Tox21(mode='test')\n",
    "\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)\n",
    "\n",
    "result_dir = 'results/'\n",
    "if not(os.path.exists(result_dir)):\n",
    "    os.mkdir(result_dir)\n",
    "name = 'transformer_gen'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f89e638",
   "metadata": {},
   "source": [
    "### Hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c380bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_tokens = 101\n",
    "num_token_vals = 112\n",
    "num_emb = 64\n",
    "num_neurons=4\n",
    "num_heads=4\n",
    "num_blocks=10\n",
    "causal=True\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 200 # max. number of epochs\n",
    "max_patience = 10 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dc57fb3",
   "metadata": {},
   "source": [
    "### Initialize Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94c83451",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Transformer(num_tokens=num_tokens, num_token_vals=num_token_vals, num_emb=num_emb, num_neurons=num_neurons, num_heads=num_heads, num_blocks=num_blocks, device=device)\n",
    "model = model.to(device)\n",
    "# Print the summary (like in Keras)\n",
    "print(summary(model, torch.zeros(1, num_tokens, dtype=torch.long).to(device), show_input=False, show_hierarchical=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81da0195",
   "metadata": {},
   "source": [
    "### Let's play! Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "712c07ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d03d871",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "nll_val, rec_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer,\n",
    "                           training_loader=training_loader, val_loader=val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b62aa1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss, test_rec = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "\n",
    "with open(result_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write('Test NLL: ' + str(test_loss)+'\\n'+'Test REC: ' + str(test_rec))\n",
    "    f.close()\n",
    "\n",
    "plot_curve(result_dir + name, nll_val, ylabel='nll')\n",
    "plot_curve(result_dir + name, rec_val, ylabel='rec')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1686db7",
   "metadata": {},
   "source": [
    "### Data visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4f916aa",
   "metadata": {},
   "source": [
    "#### Auxiliary functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d2a67fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_valid_smiles(smiles):\n",
    "    \"\"\" Using RDKit to calculate whether molecule is syntactically and semantically valid.\n",
    "    \"\"\"\n",
    "    if smiles == \"\":\n",
    "        return False\n",
    "    try:\n",
    "        return Chem.MolFromSmiles(smiles, sanitize=True) is not None\n",
    "    except:\n",
    "        return False"
   ]
  },
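  {
   "cell_type": "markdown",
   "id": "b8c9d0e1",
   "metadata": {},
   "source": [
    "A quick usage example (arbitrary inputs): a well-formed SMILES passes the check, while a string with an unclosed ring does not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d0e1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# quick check of the validity filter (arbitrary example inputs)\n",
    "print(is_valid_smiles('CCO'))   # ethanol -> expected: True\n",
    "print(is_valid_smiles('C1CC'))  # unclosed ring -> expected: False"
   ]
  },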
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfecfbfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_smiles(smiles, nrows, ncols, path=None):\n",
    "\n",
    "    if len(smiles) < nrows * ncols:\n",
    "        raise AssertionError(\"Provide more examples\")\n",
    "\n",
    "    fig, axes = plt.subplots(nrows, ncols)\n",
    "    plt.subplots_adjust(wspace=0., hspace=0.)\n",
    "    idx = 0\n",
    "    for row in [idx for idx in range(nrows)]:\n",
    "        for column in [idx for idx in range(ncols)]:\n",
    "            ax = axes[row, column]\n",
    "            # ax.set_title(f\"Image ({row}, {column})\")\n",
    "            ax.axis('off')\n",
    "            ax.imshow(Draw.MolToImage(Chem.MolFromSmiles(smiles[idx])))\n",
    "            idx += 1\n",
    "    if path:\n",
    "        plt.savefig(path, dpi=300, bbox_inches='tight')\n",
    "        plt.tight_layout()\n",
    "    else:\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2be99f99",
   "metadata": {},
   "source": [
    "#### Sample new data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "440eabf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample molecules\n",
    "model_best = torch.load(result_dir + name + '.model')\n",
    "model_best = model_best.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "001bd9de",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = 1028\n",
    "x_sample = model_best.sample(batch_size=num_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20e054fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "alphabet_dict_reverse = np.load(os.path.join('molecules', 'alphabet_dict_reverse.npy'), allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eac5214",
   "metadata": {},
   "outputs": [],
   "source": [
    "smiles = []\n",
    "for n in range(x_sample.shape[0]):\n",
    "    s = ''\n",
    "    for i in range(1, x_sample.shape[1]):\n",
    "        c = alphabet_dict_reverse[x_sample[n, i]]\n",
    "        if c == 'unk':\n",
    "            break\n",
    "        else:\n",
    "            s = s + c\n",
    "    if is_valid_smiles(s):\n",
    "        smiles.append(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92ed3155",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(f'The percentage of valid molecules: {len(smiles)/num_samples}')\n",
    "with open(result_dir + name + '_validity.txt', \"w\") as f:\n",
    "    f.write(f'The percentage of valid molecules: {len(smiles)/num_samples}')\n",
    "    f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a77eaa6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_smiles(smiles, 6, 6, path=result_dir + 'transformer_generated_molecules.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
